{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ab4d0442",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e9d6261b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import (\n",
    "    AutoTokenizer,\n",
    "    T5ForConditionalGeneration,\n",
    "    AutoConfig,\n",
    ")\n",
    "from pytorch_lightning import LightningModule\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d426b62d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained('./out-small-1.1')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "94be4a10",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Module(LightningModule):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        config = AutoConfig.from_pretrained(\n",
    "            './out-small-1.1'\n",
    "        )\n",
    "        self.model = T5ForConditionalGeneration.from_pretrained(\n",
    "            './out-small-1.1',\n",
    "            config=config,\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e4b5d292",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "total 2.7G\r\n",
      "-rwxrwxrwx 1 ubuntu ubuntu 917M Oct 26 23:09 'model-epoch=04-step=42000.ckpt'\r\n",
      "-rwxrwxrwx 1 ubuntu ubuntu 917M Oct 26 23:35 'model-epoch=04-step=44000.ckpt'\r\n",
      "-rwxrwxrwx 1 ubuntu ubuntu 917M Oct 27 00:01 'model-epoch=04-step=46000.ckpt'\r\n"
     ]
    }
   ],
   "source": [
    "!ls -lh logs/small"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "74be228f",
   "metadata": {},
   "outputs": [],
   "source": [
    "model = Module()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "fc74cda8",
   "metadata": {},
   "outputs": [],
   "source": [
    "weights = model.state_dict()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "cd9d6272",
   "metadata": {},
   "outputs": [],
   "source": [
    "old_weights = torch.load('logs/small/model-epoch=04-step=46000.ckpt',\n",
    "                             map_location=torch.device('cpu'))['state_dict'].items()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "652becba",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "model.shared.weight model.shared.weight\n",
      "model.encoder.embed_tokens.weight model.encoder.embed_tokens.weight\n",
      "model.encoder.block.0.layer.0.SelfAttention.q.weight model.encoder.block.0.layer.0.SelfAttention.q.weight\n",
      "model.encoder.block.0.layer.0.SelfAttention.k.weight model.encoder.block.0.layer.0.SelfAttention.k.weight\n",
      "model.encoder.block.0.layer.0.SelfAttention.v.weight model.encoder.block.0.layer.0.SelfAttention.v.weight\n",
      "model.encoder.block.0.layer.0.SelfAttention.o.weight model.encoder.block.0.layer.0.SelfAttention.o.weight\n",
      "model.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight model.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight\n",
      "model.encoder.block.0.layer.0.layer_norm.weight model.encoder.block.0.layer.0.layer_norm.weight\n",
      "model.encoder.block.0.layer.1.DenseReluDense.wi_0.weight model.encoder.block.0.layer.1.DenseReluDense.wi_0.weight\n",
      "model.encoder.block.0.layer.1.DenseReluDense.wi_1.weight model.encoder.block.0.layer.1.DenseReluDense.wi_1.weight\n",
      "model.encoder.block.0.layer.1.DenseReluDense.wo.weight model.encoder.block.0.layer.1.DenseReluDense.wo.weight\n",
      "model.encoder.block.0.layer.1.layer_norm.weight model.encoder.block.0.layer.1.layer_norm.weight\n",
      "model.encoder.block.1.layer.0.SelfAttention.q.weight model.encoder.block.1.layer.0.SelfAttention.q.weight\n",
      "model.encoder.block.1.layer.0.SelfAttention.k.weight model.encoder.block.1.layer.0.SelfAttention.k.weight\n",
      "model.encoder.block.1.layer.0.SelfAttention.v.weight model.encoder.block.1.layer.0.SelfAttention.v.weight\n",
      "model.encoder.block.1.layer.0.SelfAttention.o.weight model.encoder.block.1.layer.0.SelfAttention.o.weight\n",
      "model.encoder.block.1.layer.0.layer_norm.weight model.encoder.block.1.layer.0.layer_norm.weight\n",
      "model.encoder.block.1.layer.1.DenseReluDense.wi_0.weight model.encoder.block.1.layer.1.DenseReluDense.wi_0.weight\n",
      "model.encoder.block.1.layer.1.DenseReluDense.wi_1.weight model.encoder.block.1.layer.1.DenseReluDense.wi_1.weight\n",
      "model.encoder.block.1.layer.1.DenseReluDense.wo.weight model.encoder.block.1.layer.1.DenseReluDense.wo.weight\n",
      "model.encoder.block.1.layer.1.layer_norm.weight model.encoder.block.1.layer.1.layer_norm.weight\n",
      "model.encoder.block.2.layer.0.SelfAttention.q.weight model.encoder.block.2.layer.0.SelfAttention.q.weight\n",
      "model.encoder.block.2.layer.0.SelfAttention.k.weight model.encoder.block.2.layer.0.SelfAttention.k.weight\n",
      "model.encoder.block.2.layer.0.SelfAttention.v.weight model.encoder.block.2.layer.0.SelfAttention.v.weight\n",
      "model.encoder.block.2.layer.0.SelfAttention.o.weight model.encoder.block.2.layer.0.SelfAttention.o.weight\n",
      "model.encoder.block.2.layer.0.layer_norm.weight model.encoder.block.2.layer.0.layer_norm.weight\n",
      "model.encoder.block.2.layer.1.DenseReluDense.wi_0.weight model.encoder.block.2.layer.1.DenseReluDense.wi_0.weight\n",
      "model.encoder.block.2.layer.1.DenseReluDense.wi_1.weight model.encoder.block.2.layer.1.DenseReluDense.wi_1.weight\n",
      "model.encoder.block.2.layer.1.DenseReluDense.wo.weight model.encoder.block.2.layer.1.DenseReluDense.wo.weight\n",
      "model.encoder.block.2.layer.1.layer_norm.weight model.encoder.block.2.layer.1.layer_norm.weight\n",
      "model.encoder.block.3.layer.0.SelfAttention.q.weight model.encoder.block.3.layer.0.SelfAttention.q.weight\n",
      "model.encoder.block.3.layer.0.SelfAttention.k.weight model.encoder.block.3.layer.0.SelfAttention.k.weight\n",
      "model.encoder.block.3.layer.0.SelfAttention.v.weight model.encoder.block.3.layer.0.SelfAttention.v.weight\n",
      "model.encoder.block.3.layer.0.SelfAttention.o.weight model.encoder.block.3.layer.0.SelfAttention.o.weight\n",
      "model.encoder.block.3.layer.0.layer_norm.weight model.encoder.block.3.layer.0.layer_norm.weight\n",
      "model.encoder.block.3.layer.1.DenseReluDense.wi_0.weight model.encoder.block.3.layer.1.DenseReluDense.wi_0.weight\n",
      "model.encoder.block.3.layer.1.DenseReluDense.wi_1.weight model.encoder.block.3.layer.1.DenseReluDense.wi_1.weight\n",
      "model.encoder.block.3.layer.1.DenseReluDense.wo.weight model.encoder.block.3.layer.1.DenseReluDense.wo.weight\n",
      "model.encoder.block.3.layer.1.layer_norm.weight model.encoder.block.3.layer.1.layer_norm.weight\n",
      "model.encoder.block.4.layer.0.SelfAttention.q.weight model.encoder.block.4.layer.0.SelfAttention.q.weight\n",
      "model.encoder.block.4.layer.0.SelfAttention.k.weight model.encoder.block.4.layer.0.SelfAttention.k.weight\n",
      "model.encoder.block.4.layer.0.SelfAttention.v.weight model.encoder.block.4.layer.0.SelfAttention.v.weight\n",
      "model.encoder.block.4.layer.0.SelfAttention.o.weight model.encoder.block.4.layer.0.SelfAttention.o.weight\n",
      "model.encoder.block.4.layer.0.layer_norm.weight model.encoder.block.4.layer.0.layer_norm.weight\n",
      "model.encoder.block.4.layer.1.DenseReluDense.wi_0.weight model.encoder.block.4.layer.1.DenseReluDense.wi_0.weight\n",
      "model.encoder.block.4.layer.1.DenseReluDense.wi_1.weight model.encoder.block.4.layer.1.DenseReluDense.wi_1.weight\n",
      "model.encoder.block.4.layer.1.DenseReluDense.wo.weight model.encoder.block.4.layer.1.DenseReluDense.wo.weight\n",
      "model.encoder.block.4.layer.1.layer_norm.weight model.encoder.block.4.layer.1.layer_norm.weight\n",
      "model.encoder.block.5.layer.0.SelfAttention.q.weight model.encoder.block.5.layer.0.SelfAttention.q.weight\n",
      "model.encoder.block.5.layer.0.SelfAttention.k.weight model.encoder.block.5.layer.0.SelfAttention.k.weight\n",
      "model.encoder.block.5.layer.0.SelfAttention.v.weight model.encoder.block.5.layer.0.SelfAttention.v.weight\n",
      "model.encoder.block.5.layer.0.SelfAttention.o.weight model.encoder.block.5.layer.0.SelfAttention.o.weight\n",
      "model.encoder.block.5.layer.0.layer_norm.weight model.encoder.block.5.layer.0.layer_norm.weight\n",
      "model.encoder.block.5.layer.1.DenseReluDense.wi_0.weight model.encoder.block.5.layer.1.DenseReluDense.wi_0.weight\n",
      "model.encoder.block.5.layer.1.DenseReluDense.wi_1.weight model.encoder.block.5.layer.1.DenseReluDense.wi_1.weight\n",
      "model.encoder.block.5.layer.1.DenseReluDense.wo.weight model.encoder.block.5.layer.1.DenseReluDense.wo.weight\n",
      "model.encoder.block.5.layer.1.layer_norm.weight model.encoder.block.5.layer.1.layer_norm.weight\n",
      "model.encoder.final_layer_norm.weight model.encoder.final_layer_norm.weight\n",
      "model.decoder.embed_tokens.weight model.decoder.embed_tokens.weight\n",
      "model.decoder.block.0.layer.0.SelfAttention.q.weight model.decoder.block.0.layer.0.SelfAttention.q.weight\n",
      "model.decoder.block.0.layer.0.SelfAttention.k.weight model.decoder.block.0.layer.0.SelfAttention.k.weight\n",
      "model.decoder.block.0.layer.0.SelfAttention.v.weight model.decoder.block.0.layer.0.SelfAttention.v.weight\n",
      "model.decoder.block.0.layer.0.SelfAttention.o.weight model.decoder.block.0.layer.0.SelfAttention.o.weight\n",
      "model.decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight model.decoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight\n",
      "model.decoder.block.0.layer.0.layer_norm.weight model.decoder.block.0.layer.0.layer_norm.weight\n",
      "model.decoder.block.0.layer.1.EncDecAttention.q.weight model.decoder.block.0.layer.1.EncDecAttention.q.weight\n",
      "model.decoder.block.0.layer.1.EncDecAttention.k.weight model.decoder.block.0.layer.1.EncDecAttention.k.weight\n",
      "model.decoder.block.0.layer.1.EncDecAttention.v.weight model.decoder.block.0.layer.1.EncDecAttention.v.weight\n",
      "model.decoder.block.0.layer.1.EncDecAttention.o.weight model.decoder.block.0.layer.1.EncDecAttention.o.weight\n",
      "model.decoder.block.0.layer.1.layer_norm.weight model.decoder.block.0.layer.1.layer_norm.weight\n",
      "model.decoder.block.0.layer.2.DenseReluDense.wi_0.weight model.decoder.block.0.layer.2.DenseReluDense.wi_0.weight\n",
      "model.decoder.block.0.layer.2.DenseReluDense.wi_1.weight model.decoder.block.0.layer.2.DenseReluDense.wi_1.weight\n",
      "model.decoder.block.0.layer.2.DenseReluDense.wo.weight model.decoder.block.0.layer.2.DenseReluDense.wo.weight\n",
      "model.decoder.block.0.layer.2.layer_norm.weight model.decoder.block.0.layer.2.layer_norm.weight\n",
      "model.decoder.block.1.layer.0.SelfAttention.q.weight model.decoder.block.1.layer.0.SelfAttention.q.weight\n",
      "model.decoder.block.1.layer.0.SelfAttention.k.weight model.decoder.block.1.layer.0.SelfAttention.k.weight\n",
      "model.decoder.block.1.layer.0.SelfAttention.v.weight model.decoder.block.1.layer.0.SelfAttention.v.weight\n",
      "model.decoder.block.1.layer.0.SelfAttention.o.weight model.decoder.block.1.layer.0.SelfAttention.o.weight\n",
      "model.decoder.block.1.layer.0.layer_norm.weight model.decoder.block.1.layer.0.layer_norm.weight\n",
      "model.decoder.block.1.layer.1.EncDecAttention.q.weight model.decoder.block.1.layer.1.EncDecAttention.q.weight\n",
      "model.decoder.block.1.layer.1.EncDecAttention.k.weight model.decoder.block.1.layer.1.EncDecAttention.k.weight\n",
      "model.decoder.block.1.layer.1.EncDecAttention.v.weight model.decoder.block.1.layer.1.EncDecAttention.v.weight\n",
      "model.decoder.block.1.layer.1.EncDecAttention.o.weight model.decoder.block.1.layer.1.EncDecAttention.o.weight\n",
      "model.decoder.block.1.layer.1.layer_norm.weight model.decoder.block.1.layer.1.layer_norm.weight\n",
      "model.decoder.block.1.layer.2.DenseReluDense.wi_0.weight model.decoder.block.1.layer.2.DenseReluDense.wi_0.weight\n",
      "model.decoder.block.1.layer.2.DenseReluDense.wi_1.weight model.decoder.block.1.layer.2.DenseReluDense.wi_1.weight\n",
      "model.decoder.block.1.layer.2.DenseReluDense.wo.weight model.decoder.block.1.layer.2.DenseReluDense.wo.weight\n",
      "model.decoder.block.1.layer.2.layer_norm.weight model.decoder.block.1.layer.2.layer_norm.weight\n",
      "model.decoder.block.2.layer.0.SelfAttention.q.weight model.decoder.block.2.layer.0.SelfAttention.q.weight\n",
      "model.decoder.block.2.layer.0.SelfAttention.k.weight model.decoder.block.2.layer.0.SelfAttention.k.weight\n",
      "model.decoder.block.2.layer.0.SelfAttention.v.weight model.decoder.block.2.layer.0.SelfAttention.v.weight\n",
      "model.decoder.block.2.layer.0.SelfAttention.o.weight model.decoder.block.2.layer.0.SelfAttention.o.weight\n",
      "model.decoder.block.2.layer.0.layer_norm.weight model.decoder.block.2.layer.0.layer_norm.weight\n",
      "model.decoder.block.2.layer.1.EncDecAttention.q.weight model.decoder.block.2.layer.1.EncDecAttention.q.weight\n",
      "model.decoder.block.2.layer.1.EncDecAttention.k.weight model.decoder.block.2.layer.1.EncDecAttention.k.weight\n",
      "model.decoder.block.2.layer.1.EncDecAttention.v.weight model.decoder.block.2.layer.1.EncDecAttention.v.weight\n",
      "model.decoder.block.2.layer.1.EncDecAttention.o.weight model.decoder.block.2.layer.1.EncDecAttention.o.weight\n",
      "model.decoder.block.2.layer.1.layer_norm.weight model.decoder.block.2.layer.1.layer_norm.weight\n",
      "model.decoder.block.2.layer.2.DenseReluDense.wi_0.weight model.decoder.block.2.layer.2.DenseReluDense.wi_0.weight\n",
      "model.decoder.block.2.layer.2.DenseReluDense.wi_1.weight model.decoder.block.2.layer.2.DenseReluDense.wi_1.weight\n",
      "model.decoder.block.2.layer.2.DenseReluDense.wo.weight model.decoder.block.2.layer.2.DenseReluDense.wo.weight\n",
      "model.decoder.block.2.layer.2.layer_norm.weight model.decoder.block.2.layer.2.layer_norm.weight\n",
      "model.decoder.block.3.layer.0.SelfAttention.q.weight model.decoder.block.3.layer.0.SelfAttention.q.weight\n",
      "model.decoder.block.3.layer.0.SelfAttention.k.weight model.decoder.block.3.layer.0.SelfAttention.k.weight\n",
      "model.decoder.block.3.layer.0.SelfAttention.v.weight model.decoder.block.3.layer.0.SelfAttention.v.weight\n",
      "model.decoder.block.3.layer.0.SelfAttention.o.weight model.decoder.block.3.layer.0.SelfAttention.o.weight\n",
      "model.decoder.block.3.layer.0.layer_norm.weight model.decoder.block.3.layer.0.layer_norm.weight\n",
      "model.decoder.block.3.layer.1.EncDecAttention.q.weight model.decoder.block.3.layer.1.EncDecAttention.q.weight\n",
      "model.decoder.block.3.layer.1.EncDecAttention.k.weight model.decoder.block.3.layer.1.EncDecAttention.k.weight\n",
      "model.decoder.block.3.layer.1.EncDecAttention.v.weight model.decoder.block.3.layer.1.EncDecAttention.v.weight\n",
      "model.decoder.block.3.layer.1.EncDecAttention.o.weight model.decoder.block.3.layer.1.EncDecAttention.o.weight\n",
      "model.decoder.block.3.layer.1.layer_norm.weight model.decoder.block.3.layer.1.layer_norm.weight\n",
      "model.decoder.block.3.layer.2.DenseReluDense.wi_0.weight model.decoder.block.3.layer.2.DenseReluDense.wi_0.weight\n",
      "model.decoder.block.3.layer.2.DenseReluDense.wi_1.weight model.decoder.block.3.layer.2.DenseReluDense.wi_1.weight\n",
      "model.decoder.block.3.layer.2.DenseReluDense.wo.weight model.decoder.block.3.layer.2.DenseReluDense.wo.weight\n",
      "model.decoder.block.3.layer.2.layer_norm.weight model.decoder.block.3.layer.2.layer_norm.weight\n",
      "model.decoder.block.4.layer.0.SelfAttention.q.weight model.decoder.block.4.layer.0.SelfAttention.q.weight\n",
      "model.decoder.block.4.layer.0.SelfAttention.k.weight model.decoder.block.4.layer.0.SelfAttention.k.weight\n",
      "model.decoder.block.4.layer.0.SelfAttention.v.weight model.decoder.block.4.layer.0.SelfAttention.v.weight\n",
      "model.decoder.block.4.layer.0.SelfAttention.o.weight model.decoder.block.4.layer.0.SelfAttention.o.weight\n",
      "model.decoder.block.4.layer.0.layer_norm.weight model.decoder.block.4.layer.0.layer_norm.weight\n",
      "model.decoder.block.4.layer.1.EncDecAttention.q.weight model.decoder.block.4.layer.1.EncDecAttention.q.weight\n",
      "model.decoder.block.4.layer.1.EncDecAttention.k.weight model.decoder.block.4.layer.1.EncDecAttention.k.weight\n",
      "model.decoder.block.4.layer.1.EncDecAttention.v.weight model.decoder.block.4.layer.1.EncDecAttention.v.weight\n",
      "model.decoder.block.4.layer.1.EncDecAttention.o.weight model.decoder.block.4.layer.1.EncDecAttention.o.weight\n",
      "model.decoder.block.4.layer.1.layer_norm.weight model.decoder.block.4.layer.1.layer_norm.weight\n",
      "model.decoder.block.4.layer.2.DenseReluDense.wi_0.weight model.decoder.block.4.layer.2.DenseReluDense.wi_0.weight\n",
      "model.decoder.block.4.layer.2.DenseReluDense.wi_1.weight model.decoder.block.4.layer.2.DenseReluDense.wi_1.weight\n",
      "model.decoder.block.4.layer.2.DenseReluDense.wo.weight model.decoder.block.4.layer.2.DenseReluDense.wo.weight\n",
      "model.decoder.block.4.layer.2.layer_norm.weight model.decoder.block.4.layer.2.layer_norm.weight\n",
      "model.decoder.block.5.layer.0.SelfAttention.q.weight model.decoder.block.5.layer.0.SelfAttention.q.weight\n",
      "model.decoder.block.5.layer.0.SelfAttention.k.weight model.decoder.block.5.layer.0.SelfAttention.k.weight\n",
      "model.decoder.block.5.layer.0.SelfAttention.v.weight model.decoder.block.5.layer.0.SelfAttention.v.weight\n",
      "model.decoder.block.5.layer.0.SelfAttention.o.weight model.decoder.block.5.layer.0.SelfAttention.o.weight\n",
      "model.decoder.block.5.layer.0.layer_norm.weight model.decoder.block.5.layer.0.layer_norm.weight\n",
      "model.decoder.block.5.layer.1.EncDecAttention.q.weight model.decoder.block.5.layer.1.EncDecAttention.q.weight\n",
      "model.decoder.block.5.layer.1.EncDecAttention.k.weight model.decoder.block.5.layer.1.EncDecAttention.k.weight\n",
      "model.decoder.block.5.layer.1.EncDecAttention.v.weight model.decoder.block.5.layer.1.EncDecAttention.v.weight\n",
      "model.decoder.block.5.layer.1.EncDecAttention.o.weight model.decoder.block.5.layer.1.EncDecAttention.o.weight\n",
      "model.decoder.block.5.layer.1.layer_norm.weight model.decoder.block.5.layer.1.layer_norm.weight\n",
      "model.decoder.block.5.layer.2.DenseReluDense.wi_0.weight model.decoder.block.5.layer.2.DenseReluDense.wi_0.weight\n",
      "model.decoder.block.5.layer.2.DenseReluDense.wi_1.weight model.decoder.block.5.layer.2.DenseReluDense.wi_1.weight\n",
      "model.decoder.block.5.layer.2.DenseReluDense.wo.weight model.decoder.block.5.layer.2.DenseReluDense.wo.weight\n",
      "model.decoder.block.5.layer.2.layer_norm.weight model.decoder.block.5.layer.2.layer_norm.weight\n",
      "model.decoder.final_layer_norm.weight model.decoder.final_layer_norm.weight\n",
      "model.lm_head.weight model.lm_head.weight\n"
     ]
    }
   ],
   "source": [
    "for k, v in old_weights:\n",
    "    new_k = k.replace('._orig_mod', '')\n",
    "    print(k, new_k)\n",
    "    weights[new_k] = v"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "839ae22f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.load_state_dict(weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "c698e532",
   "metadata": {},
   "outputs": [],
   "source": [
    "from utils.copied_utils import (\n",
    "    compute_input_and_target_lengths,\n",
    "    DataCollatorForT5MLM,\n",
    "    tokenize_function,\n",
    "    DataCollatorForNI,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "60e223c7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(568, 114)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "before_mask_input_length, target_length = compute_input_and_target_lengths(\n",
    "    inputs_length=512,\n",
    "    noise_density=0.15,\n",
    "    mean_noise_span_length=3.0,\n",
    ")\n",
    "before_mask_input_length, target_length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "e78b0fbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "from streaming.base.format.mds.encodings import Encoding, _encodings\n",
    "from streaming import StreamingDataset\n",
    "import torch\n",
    "import numpy as np\n",
    "\n",
    "class Int32(Encoding):\n",
    "    def encode(self, obj) -> bytes:\n",
    "        return obj.tobytes()\n",
    "\n",
    "    def decode(self, data: bytes):\n",
    "        return np.frombuffer(data, np.int32)\n",
    "\n",
    "\n",
    "_encodings['int32'] = Int32\n",
    "\n",
    "\n",
    "class DatasetFixed(torch.utils.data.Dataset):\n",
    "    def __init__(self, local):\n",
    "        self.dataset = StreamingDataset(local=local)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        data = self.dataset[idx]\n",
    "        data.pop('token_type_ids', None)\n",
    "        for k in data.keys():\n",
    "            data[k] = data[k].astype(np.int64)\n",
    "        return data\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "d3c2bef2",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = DatasetFixed(local='/home/ubuntu/nanot5-512')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "317b0cda",
   "metadata": {},
   "outputs": [],
   "source": [
    "data_collator = DataCollatorForT5MLM(\n",
    "    tokenizer=tokenizer,\n",
    "    noise_density=0.15,\n",
    "    mean_noise_span_length=3.0,\n",
    "    input_length=512,\n",
    "    target_length=target_length,\n",
    "    pad_token_id=0,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "4028b237",
   "metadata": {},
   "outputs": [],
   "source": [
    "batch = [dataset[i] for i in range(5000000, 5000001, 1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d788fba7",
   "metadata": {},
   "outputs": [],
   "source": [
    "padded = data_collator(batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "ad3a2d8e",
   "metadata": {},
   "outputs": [],
   "source": [
    "o = model.model(**padded)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "ba335b10",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[-0.1583, -1.0797,  2.4659,  ..., -0.3984, -0.3889, -0.9530],\n",
       "         [-0.9148, -3.7254,  8.1551,  ..., -1.9462, -2.2352, -2.9646],\n",
       "         [-0.3222, -1.2524,  9.8471,  ..., -1.1862, -1.5322, -4.7910],\n",
       "         ...,\n",
       "         [-0.9169,  1.2961,  6.9794,  ..., -0.5234,  0.4043, -0.7112],\n",
       "         [-1.0103, -1.0219,  8.2162,  ..., -2.5970, -2.2210, -3.6502],\n",
       "         [-6.6296,  2.9095, 40.6058,  ..., -4.4431, -2.1524, -3.0650]]],\n",
       "       grad_fn=<UnsafeViewBackward0>)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "o.logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "e29488ad",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'<extra_id_40>. Boleh pi IG Haliza May<extra_id_59> dia tanya<extra_id_38> tak ramai<extra_id_18> dia balik<extra_id_46></s><s><extra_id_66>28 PM<extra_id_71> sket<extra_id_67>wira konon tp<extra_id_6> makcik bawang.<extra_id_53>, konfiden giler<extra_id_10></s><extra_id_61>arkan<extra_id_2> penulis yang telah<extra_id_84>rah dan menerima ketentuan Il<extra_id_8> berpisah ketika masih hidup, sakitnya<extra_id_28> akhirnya<extra_id_70> memisahkan. Banyak<extra_id_76> p<extra_id_62>atapan<extra_id_54>?? Dia pun<extra_id_12></s><s> akak kot<extra_id_0> tgu...tah br<extra_id_75> kalau die betul<extra_id_85> die<extra_id_83> Skuau at 26<extra_id_52> PM<extra_id_89>....<extra_id_77> ha...mkn la</s>'"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "padded['labels']\n",
    "tokenizer.decode(padded['labels'][0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "3206b0ff",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'<extra_id_40>. Pas<extra_id_59> Mah<extra_id_59><extra_id_59> Mah<extra_id_59> dia<extra_id_38> dia siapa<extra_id_18><extra_id_18> dia balik<extra_id_46>.<s><extra_id_66>39 PM<extra_id_71>..rg..giira<extra_id_6><extra_id_6><extra_id_6>...<extra_id_53>..<extra_id_10>ekem<extra_id_10><extra_id_10>er<extra_id_10></s><extra_id_61>yaratkan<extra_id_2> kisah<extra_id_84><extra_id_84><extra_id_84>rah dengan kembali ketentuan Il<extra_id_8> tidak<extra_id_28> itu<extra_id_28>,<extra_id_28><extra_id_28><extra_id_28> akan<extra_id_70> berlaku. Banyak<extra_id_76> p<extra_id_62>edaapan<extra_id_54>...</s> pun<extra_id_12></s><s> h<extra_id_0><extra_id_0> tgu...<extra_id_75><extra_id_75><extra_id_75> kalau betul betul<extra_id_85> die<extra_id_83> muau at 24<extra_id_52> PM<extra_id_89>.<extra_id_77> yg...</s>cm x</s>'"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(o.logits.argmax(-1).detach().cpu().numpy()[0])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
